{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for Pre-trained Knowledge Graph Embedding (on UMLS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Import packages\n",
    "\n",
    "import csv\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from pyhealth.medcode import ICD9CM, ICD9PROC\n",
    "from pyhealth.medcode.pretrained_embeddings.kg_emb.models import TransE, RotatE, ComplEx, DistMult\n",
    "from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import UMLSDataset, split\n",
    "from pyhealth.medcode.pretrained_embeddings.kg_emb.tasks import link_prediction_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Pre-trained KGE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: Pandarallel will run on 64 workers.\n",
      "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n",
      "Loading UMLS knowledge graph...\n",
      "Processing UMLS knowledge graph...\n",
      "Building UMLS knowledge graph...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43842950/43842950 [00:29<00:00, 1486771.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Statistics of base dataset (dev=False):\n",
      "\t- Dataset: UMLSDataset\n",
      "\t- Number of triples: 43842950\n",
      "\t- Number of entities: 3110571\n",
      "\t- Number of relations: 965\n",
      "\t- Task name: Null\n",
      "\t- Number of samples: 0\n",
      "\n",
      "None\n",
      "Relations in KG: {'RB': 0, 'translation_of': 1, 'permuted_term_of': 2, 'SY': 3, 'AQ': 4, 'PAR': 5, 'mapped_to': 6, 'associated_with': 7, 'has_permuted_term': 8, 'has_translation': 9, 'has_transliterated_form': 10, 'measures': 11, 'parent_of': 12, 'form_of': 13, 'CHD': 14, 'has_component': 15, 'transliterated_form_of': 16, 'RO': 17, 'RN': 18, 'inverse_isa': 19, 'disposition_of': 20, 'exhibited_by': 21, 'see_from': 22, 'see': 23, 'entry_combination_of': 24, 'mapped_from': 25, 'has_causative_agent': 26, 'used_for': 27, 'use': 28, 'isa': 29, 'subset_includes_concept': 30, 'has_direct_substance': 31, 'has_ingredient': 32, 'has_tradename': 33, 'mapping_qualifier_of': 34, 'contains': 35, 'active_ingredient_of': 36, 'has_active_ingredient': 37, 'has_active_moiety': 38, 'has_member': 39, 'has_basis_of_strength_substance': 40, 'has_precise_active_ingredient': 41, 'tradename_of': 42, 'may_be_prevented_by': 43, 'may_be_treated_by': 44, 'has_contraindicated_drug': 45, 'has_part': 46, 'physiologic_effect_of': 47, 'mechanism_of_action_of': 48, 'therapeutic_class_of': 49, 'lab_number_of': 50, 'chemotherapy_regimen_has_component': 51, 'has_gdc_value': 52, 'RQ': 53, 'structural_class_of': 54, 'has_inactive_ingredient': 55, 'contraindicated_class_of': 56, 'has_free_acid_or_base_form': 57, 'has_contraindicated_class': 58, 'is_biochemical_function_of_gene_product': 59, 'gene_product_has_chemical_classification': 60, 'is_target': 61, 'process_acts_on': 62, 'has_form': 63, 'has_precise_ingredient': 64, 'has_salt_form': 65, 'has_modification': 66, 'metabolic_site_of': 67, 'has_active_metabolites': 68, 'role_played_by': 69, 'is_physiologic_effect_of_chemical_or_drug': 70, 'is_mechanism_of_action_of_chemical_or_drug': 71, 'is_modification_of': 72, 'has_expanded_form': 73, 'expanded_form_of': 74, 'part_of': 75, 'adjectival_form_of': 76, 'has_location': 77, 'noun_form_of': 78, 'related_part': 79, 'has_regional_part': 80, 'regional_part_of': 81, 'constitutional_part_of': 82, 'analyzes': 
83, 'has_answer': 84, 'has_imaged_location': 85, 'has_finding_site': 86, 'has_direct_procedure_site': 87, 'has_procedure_site': 88, 'has_indirect_procedure_site': 89, 'entire_anatomy_structure_of': 90, 'has_specimen_source_topography': 91, 'has_system': 92, 'contraindicated_with_disease': 93, 'location_of': 94, 'primary_mapped_to': 95, 'classified_as': 96, 'same_as': 97, 'classifies': 98, 'default_inpatient_classification_of': 99, 'default_outpatient_classification_of': 100, 'finding_site_of': 101, 'has_focus': 102, 'has_defining_characteristic': 103, 'has_manifestation': 104, 'ssc': 105, 'clinically_associated_with': 106, 'associated_morphology_of': 107, 'has_associated_finding': 108, 'co-occurs_with': 109, 'disease_may_have_finding': 110, 'may_treat': 111, 'has_multi_level_category': 112, 'has_single_level_category': 113, 'related_to': 114, 'evaluation_of': 115, 'is_seronet_permissible_value_for_variable': 116, 'primary_mapped_from': 117, 'member_of': 118, 'has_constitutional_part': 119, 'has_physical_part_of_anatomic_structure': 120, 'is_location_of_anatomic_structure': 121, 'has_nerve_supply': 122, 'anatomic_structure_is_physical_part_of': 123, 'disease_has_associated_anatomic_site': 124, 'gene_found_in_organism': 125, 'manifestation_of': 126, 'inheritance_type_of': 127, 'alias_of': 128, 'has_alias': 129, 'has_phenotype': 130, 'due_to': 131, 'mth_has_british_form': 132, 'mth_british_form_of': 133, 'has_associated_morphology': 134, 'has_direct_morphology': 135, 'has_risk_factor': 136, 'is_interpreted_by': 137, 'pathological_process_of': 138, 'has_occurrence': 139, 'disease_excludes_finding': 140, 'disease_has_finding': 141, 'has_pcdc_hl_permissible_value': 142, 'QB': 143, 'has_mapping_qualifier': 144, 'may_prevent': 145, 'has_entry_combination': 146, 'interpretation_of': 147, 'induces': 148, 'associated_finding_of': 149, 'temporal_context_of': 150, 'finding_context_of': 151, 'subject_relationship_context_of': 152, 'cause_of': 153, 'clinical_course_of': 154, 
'occurs_after': 155, 'occurs_before': 156, 'ddx': 157, 'intent_of': 158, 'direct_procedure_site_of': 159, 'other_mapped_to': 160, 'other_mapped_from': 161, 'is_associated_anatomic_site_of': 162, 'is_finding_of_disease': 163, 'is_primary_anatomic_site_of_disease': 164, 'is_object_guidance_for': 165, 'has_indirect_morphology': 166, 'has_pathology': 167, 'has_possibly_included_pathology': 168, 'has_specimen_source_morphology': 169, 'has_procedure_morphology': 170, 'finding_method_of': 171, 'causative_agent_of': 172, 'gene_associated_with_disease': 173, 'branch_of': 174, 'interprets': 175, 'is_location_of_biological_process': 176, 'has_imaging_focus': 177, 'laterality_of': 178, 'has_device_intended_site': 179, 'has_possibly_included_procedure_site': 180, 'has_pcdc_ews_permissible_value': 181, 'has_structural_class': 182, 'pharmacokinetics_of': 183, 'enzyme_metabolizes_chemical_or_drug': 184, 'has_contraindicated_physiologic_effect': 185, 'has_contraindicated_mechanism_of_action': 186, 'uses_substance': 187, 'dose_form_of': 188, 'has_divisor': 189, 'gene_product_plays_role_in_biological_process': 190, 'gene_plays_role_in_process': 191, 'process_includes_biological_process': 192, 'may_be_diagnosed_by': 193, 'may_diagnose': 194, 'is_normal_tissue_origin_of_disease': 195, 'gene_encodes_gene_product': 196, 'biological_process_involves_gene_product': 197, 'process_initiates_biological_process': 198, 'clinically_similar': 199, 'has_parent': 200, 'method_of': 201, 'has_branch': 202, 'disease_has_primary_anatomic_site': 203, 'is_normal_cell_origin_of_disease': 204, 'disease_may_have_associated_disease': 205, 'disease_has_associated_disease': 206, 'has_seronet_permissible_value': 207, 'entry_term_of': 208, 'gene_mapped_to_disease': 209, 'part_anatomy_structure_of': 210, 'is_associated_anatomy_of_gene_product': 211, 'gene_product_has_biochemical_function': 212, 'is_structural_domain_or_motif_of_gene_product': 213, 'negatively_regulates': 214, 'positively_regulates': 215, 
'regulates': 216, 'modifies': 217, 'device_used_by': 218, 'default_mapped_from': 219, 'default_mapped_to': 220, 'uniquely_mapped_from': 221, 'uniquely_mapped_to': 222, 'mth_has_plain_text_form': 223, 'mth_has_xml_form': 224, 'mth_plain_text_form_of': 225, 'mth_xml_form_of': 226, 'is_not_finding_of_disease': 227, 'is_abnormal_cell_of_disease': 228, 'gene_product_malfunction_associated_with_disease': 229, 'eo_disease_maps_to_human_disease': 230, 'is_not_abnormal_cell_of_disease': 231, 'has_pcdc_gct_permissible_value': 232, 'is_not_primary_anatomic_site_of_disease': 233, 'excised_anatomy_has_procedure': 234, 'target_anatomy_has_procedure': 235, 'has_associated_procedure': 236, 'procedure_has_excised_anatomy': 237, 'procedure_has_partially_excised_anatomy': 238, 'may_be_finding_of_disease': 239, 'is_not_normal_cell_origin_of_disease': 240, 'may_be_molecular_abnormality_of_disease': 241, 'gene_involved_in_pathogenesis_of_disease': 242, 'may_be_cytogenetic_abnormality_of_disease': 243, 'is_not_normal_tissue_origin_of_disease': 244, 'has_given_pharmaceutical_substance': 245, 'has_challenge': 246, 'is_chemical_classification_of_gene_product': 247, 'biological_process_has_result_chemical_or_drug': 248, 'has_target': 249, 'uses_device': 250, 'has_possibly_included_procedure_device': 251, 'has_direct_device': 252, 'has_entry_term': 253, 'precise_ingredient_of': 254, 'disease_has_normal_tissue_origin': 255, 'gene_product_expressed_in_tissue': 256, 'diagnostic_criteria_of': 257, 'has_diagnostic_criteria': 258, 'direct_substance_of': 259, 'indirect_procedure_site_of': 260, 'route_of_administration_of': 261, 'has_route_of_administration': 262, 'has_at_risk_population': 263, 'occurs_in': 264, 'has_method': 265, 'disease_excludes_normal_tissue_origin': 266, 'disease_has_metastatic_anatomic_site': 267, 'has_inherent_location': 268, 'disease_excludes_primary_anatomic_site': 269, 'has_pcdc_os_permissible_value': 270, 'procedure_may_have_excised_anatomy': 271, 'has_ctdc_value': 272, 
'gene_product_has_associated_anatomy': 273, 'is_associated_disease_of': 274, 'has_specimen_procedure': 275, 'associated_procedure_of': 276, 'has_patient_type': 277, 'has_recipient_category': 278, 'place_traveled_to': 279, 'has_conceptual_part': 280, 'conceptual_part_of': 281, 'has_associated_observation': 282, 'component_of': 283, 'has_measurement_method': 284, 'biological_process_results_from_biological_process': 285, 'is_not_cytogenetic_abnormality_of_disease': 286, 'is_not_molecular_abnormality_of_disease': 287, 'uses_possibly_included_substance': 288, 'has_specimen_substance': 289, 'has_possibly_included_associated_finding': 290, 'has_related_factor': 291, 'acted_on_by_process': 292, 'property_of': 293, 'characterized_by': 294, 'has_associated_condition': 295, 'is_value_for_gdc_property': 296, 'has_sign_or_symptom': 297, 'has_dose_form': 298, 'tissue_is_expression_site_of_gene_product': 299, 'gene_product_affected_by_chemical_or_drug': 300, 'has_clinician_form': 301, 'icd_dagger': 302, 'life_circumstance_of': 303, 'disease_has_normal_cell_origin': 304, 'clinician_form_of': 305, 'consumer_friendly_form_of': 306, 'specimen_of': 307, 'specialty_of': 308, 'do_not_code_with': 309, 'has_consumer_friendly_form': 310, 'active_metabolites_of': 311, 'substance_used_by': 312, 'procedure_site_of': 313, 'derives': 314, 'is_pcdc_ews_permissible_value_for_variable': 315, 'is_pcdc_os_permissible_value_for_variable': 316, 'pcdc_data_type_of': 317, 'partially_excised_anatomy_has_procedure': 318, 'completely_excised_anatomy_has_procedure': 319, 'has_pcdc_aml_permissible_value': 320, 'receives_input_from': 321, 'sends_output_to': 322, 'projects_from': 323, 'arterial_supply_of': 324, 'is_exam_for': 325, 'focus_of': 326, 'chemical_or_drug_initiates_biological_process': 327, 'biological_process_is_part_of_process': 328, 'loinc_number_of': 329, 'parent_group_of': 330, 'has_property': 331, 'has_default_inpatient_classification': 332, 'has_default_outpatient_classification': 333, 
'multiply_mapped_to': 334, 'has_british_form': 335, 'british_form_of': 336, 'may_be_abnormal_cell_of_disease': 337, 'is_molecular_abnormality_of_disease': 338, 'associated_procedure_of_excluded': 339, 'associated_procedure_of_possibly_included': 340, 'during': 341, 'disease_has_cytogenetic_abnormality': 342, 'disease_may_have_molecular_abnormality': 343, 'disease_may_have_cytogenetic_abnormality': 344, 'direct_morphology_of': 345, 'has_excluded_associated_finding': 346, 'has_excluded_pathology': 347, 'lymphatic_drainage_of': 348, 'inheres_in': 349, 'has_pcdc_all_permissible_value': 350, 'has_therapeutic_class': 351, 'has_coating_material': 352, 'biological_process_has_initiator_chemical_or_drug': 353, 'mth_expanded_form_of': 354, 'mth_has_expanded_form': 355, 'has_suffix': 356, 'chemical_or_drug_has_physiologic_effect': 357, 'allele_plays_role_in_metabolism_of_chemical_or_drug': 358, 'biological_process_has_associated_location': 359, 'is_organism_source_of_gene_product': 360, 'is_pcdc_aml_permissible_value_for_variable': 361, 'is_pcdc_gct_permissible_value_for_variable': 362, 'is_pcdc_hl_permissible_value_for_variable': 363, 'procedure_has_completely_excised_anatomy': 364, 'defining_characteristic_of': 365, 'population_at_risk_for': 366, 'related_factor_of': 367, 'actual_outcome_of': 368, 'expected_outcome_of': 369, 'is_dipg_dmg_permissible_value_for_variable': 370, 'anterior_to': 371, 'continuous_proximally_with': 372, 'continuous_with': 373, 'distal_to': 374, 'proximal_to': 375, 'lateral_to': 376, 'anterolateral_to': 377, 'direct_right_of': 378, 'superior_to': 379, 'surrounds': 380, 'direct_left_of': 381, 'inferior_to': 382, 'medial_to': 383, 'posterior_to': 384, 'anatomic_structure_has_location': 385, 'receives_attachment_from': 386, 'complex_has_physical_part': 387, 'nerve_supply_of': 388, 'venous_drainage_of': 389, 'receives_drainage_from': 390, 'develops_into': 391, 'bounds': 392, 'related_object': 393, 'has_excluded_procedure_site': 394, 'direct_device_of': 
395, 'adheres_to': 396, 'has_dipg_dmg_permissible_value': 397, 'gene_product_is_physical_part_of': 398, 'homonym_for': 399, 'procedure_has_target_anatomy': 400, 'has_finding_method': 401, 'anteroinferior_to': 402, 'attaches_to': 403, 'induced_by': 404, 'has_pharmaceutical_route': 405, 'has_maneuver_type': 406, 'realization_of': 407, 'measured_by': 408, 'finding_informer_of': 409, 'special_category_includes_neoplasm': 410, 'associated_genetic_condition': 411, 'may_be_associated_disease_of_disease': 412, 'genetic_biomarker_related_to': 413, 'articulates_with': 414, 'has_view_type': 415, 'has_direct_site': 416, 'tributary_of': 417, 'has_tributary': 418, 'drains_into': 419, 'connection_type_of': 420, 'chemical_or_drug_affects_cell_type_or_tissue': 421, 'biological_process_has_result_anatomy': 422, 'gene_product_has_organism_source': 423, 'uses': 424, 'is_action_guidance_for': 425, 'manufactures': 426, 'chemical_or_drug_affects_gene_product': 427, 'chromosome_mapped_to_disease': 428, 'ingredient_of': 429, 'has_icdc_value': 430, 'biomarker_type_includes_gene_product': 431, 'is_marked_by_gene_product': 432, 'multiply_mapped_from': 433, 'has_process_output': 434, 'gene_product_is_biomarker_type': 435, 'has_compositional_material': 436, 'used_by': 437, 'diagnosed_by': 438, 'has_possibly_included_method': 439, 'has_excluded_method': 440, 'inverse_during': 441, 'has_inherent_attribute': 442, 'time_aspect_of': 443, 'continuous_distally_with': 444, 'associated_disease': 445, 'homonym_of': 446, 'sign_or_symptom_of': 447, 'has_class': 448, 'has_specimen_source_identity': 449, 'scale_type_of': 450, 'access_of': 451, 'has_panel_element': 452, 'has_procedure_device': 453, 'has_indirect_device': 454, 'concept_in_subset': 455, 'has_possibly_included_component': 456, 'add_on_code_for': 457, 'is_grade_of_disease': 458, 'is_stage_of_disease': 459, 'relative_to_part_of': 460, 'energy_used_by': 461, 'gene_product_is_biomarker_of': 462, 'has_arterial_supply': 463, 'procedure_context_of': 
464, 'specimen_source_topography_of': 465, 'specimen_procedure_of': 466, 'has_specimen': 467, 'uses_excluded_substance': 468, 'has_supersystem': 469, 'relative_to': 470, 'pharmaceutical_state_of_matter_of': 471, 'pharmaceutical_basic_dose_form_of': 472, 'state_of_matter_of': 473, 'has_basic_dose_form': 474, 'referred_to_by': 475, 'regimen_has_accepted_use_for_disease': 476, 'associated_condition_of': 477, 'priority_of': 478, 'gene_is_biomarker_of': 479, 'has_print_name': 480, 'print_name_of': 481, 'is_pcdc_all_permissible_value_for_variable': 482, 'exhibits': 483, 'gene_product_is_element_in_pathway': 484, 'gene_is_element_in_pathway': 485, 'biological_process_has_initiator_process': 486, 'biological_process_has_result_biological_process': 487, 'surrounded_by': 488, 'has_entire_anatomy_structure': 489, 'has_subject_relationship_context': 490, 'has_finding_informer': 491, 'prev_name_of': 492, 'prev_symbol_of': 493, 'has_prev_name': 494, 'has_prev_symbol': 495, 'access_device_used_by': 496, 'method_of_possibly_included': 497, 'role_has_domain': 498, 'gene_has_abnormality': 499, 'is_ctdc_value_of': 500, 'role_has_range': 501, 'disease_mapped_to_chromosome': 502, 'cytogenetic_abnormality_involves_chromosome': 503, 'gene_in_chromosomal_location': 504, 'has_lab_number': 505, 'associated_with_malfunction_of_gene_product': 506, 'characterizes': 507, 'has_interpretation': 508, 'pathology_of': 509, 'recipient_category_of': 510, 'has_result': 511, 'anatomy_originated_from_biological_process': 512, 'specimen_substance_of': 513, 'has_context_binding': 514, 'supported_concept_property_in': 515, 'supported_concept_relationship_in': 516, 'context_binding_of': 517, 'icd_asterisk': 518, 'disease_may_have_normal_cell_origin': 519, 'has_projection': 520, 'receives_projection': 521, 'has_venous_drainage': 522, 'class_of': 523, 'analyzed_by': 524, 'system_of': 525, 'scale_of': 526, 'uses_access_device': 527, 'has_gene_product_element': 528, 'effect_may_be_inhibited_by': 529, 
'procedure_device_of': 530, 'has_modality_subtype': 531, 'secondary_segmental_supply_of': 532, 'segmental_supply_of': 533, 'primary_segmental_supply_of': 534, 'anteromedial_to': 535, 'has_definitional_manifestation': 536, 'has_adherent': 537, 'severity_of': 538, 'derives_from': 539, 'has_excluded_procedure_device': 540, 'uses_energy': 541, 'has_germ_origin': 542, 'chemical_or_drug_is_product_of_biological_process': 543, 'disease_excludes_normal_cell_origin': 544, 'adjacent_to': 545, 'procedure_has_imaged_anatomy': 546, 'matures_into': 547, 'surgical_extent_of': 548, 'method_of_excluded': 549, 'has_evaluation': 550, 'disease_excludes_metastatic_anatomic_site': 551, 'pharmaceutical_intended_site_of': 552, 'pharmaceutical_transformation_of': 553, 'pharmaceutical_administration_method_of': 554, 'pharmaceutical_release_characteristics_of': 555, 'basic_dose_form_of': 556, 'dose_form_intended_site_of': 557, 'dose_form_administration_method_of': 558, 'dose_form_release_characteristic_of': 559, 'dose_form_transformation_of': 560, 'device_intended_site_of': 561, 'transforms_into': 562, 'has_subject': 563, 'biological_process_involves_chemical_or_drug': 564, 'may_inhibit_effect_of': 565, 'disease_has_abnormal_cell': 566, 'disease_may_have_abnormal_cell': 567, 'connected_to': 568, 'procedure_may_have_completely_excised_anatomy': 569, 'has_pharmaceutical_state_of_matter': 570, 'has_state_of_matter': 571, 'procedure_morphology_of': 572, 'regulated_by': 573, 'chemical_or_drug_has_mechanism_of_action': 574, 'organism_has_gene': 575, 'process_involves_gene': 576, 'gene_product_encoded_by_gene': 577, 'related_to_genetic_biomarker': 578, 'disease_mapped_to_gene': 579, 'molecular_abnormality_involves_gene': 580, 'is_chromosomal_location_of_gene': 581, 'pathogenesis_of_disease_involves_gene': 582, 'disease_excludes_abnormal_cell': 583, 'is_cytogenetic_abnormality_of_disease': 584, 'associated_observation_of': 585, 'owning_affiliate_of': 586, 'is_related_to_endogenous_product': 587, 
'process_output_of': 588, 'procedure_site_of_possibly_included': 589, 'posterosuperior_to': 590, 'temporally_related_to': 591, 'is_abnormality_of_gene_product': 592, 'refers_to': 593, 'gene_product_variant_of_gene_product': 594, 'has_locale': 595, 'has_clinical_course': 596, 'surgical_approach_of': 597, 'gene_product_has_structural_domain_or_motif': 598, 'has_technique': 599, 'has_mechanism_of_action': 600, 'has_excluded_patient_type': 601, 'possibly_equivalent_to': 602, 'indirect_morphology_of': 603, 'has_precondition': 604, 'has_pharmaceutical_intended_site': 605, 'before': 606, 'is_metastatic_anatomic_site_of_disease': 607, 'has_modality_type': 608, 'projects_to': 609, 'excised_anatomy_may_have_procedure': 610, 'class_code_classified_by': 611, 'owning_subsection_of': 612, 'happens_during': 613, 'developmental_stage_of': 614, 'has_related_developmental_entity': 615, 'development_type_of': 616, 'related_developmental_entity_of': 617, 'transforms_from': 618, 'is_abnormality_of_gene': 619, 'has_procedure_context': 620, 'process_extends_to': 621, 'inherent_3d_shape_of': 622, 'has_specialty': 623, 'contained_in': 624, 'has_grade': 625, 'has_possibly_included_patient_type': 626, 'cell_type_or_tissue_affected_by_chemical_or_drug': 627, 'has_aggregation_view': 628, 'disease_excludes_cytogenetic_abnormality': 629, 'chromosome_involved_in_cytogenetic_abnormality': 630, 'gene_involved_in_molecular_abnormality': 631, 'endogenous_product_related_to': 632, 'approach_of': 633, 'anterosuperior_to': 634, 'has_add_on_code': 635, 'continuation_branch_of': 636, 'direct_site_of': 637, 'technique_of': 638, 'has_owning_subsection': 639, 'owning_section_of': 640, 'allelic_variant_of': 641, 'posteroinferior_to': 642, 'procedure_may_have_partially_excised_anatomy': 643, 'result_of': 644, 'has_filling': 645, 'classifies_class_code': 646, 'origin_of': 647, 'insertion_of': 648, 'efferent_to': 649, 'afferent_to': 650, 'product_monograph_title_of': 651, 'has_ingredients': 652, 
'gene_is_biomarker_type': 653, 'replaced_by': 654, 'has_owning_affiliate': 655, 'replaces': 656, 'revision_status_of': 657, 'has_physiologic_state': 658, 'chemical_or_drug_plays_role_in_biological_process': 659, 'posteromedial_to': 660, 'full_grown_phenotype_of': 661, 'is_component_of_chemotherapy_regimen': 662, 'disease_has_accepted_treatment_with_regimen': 663, 'reformulation_of': 664, 'allele_has_abnormality': 665, 'disease_has_associated_gene': 666, 'eo_anatomy_is_associated_with_eo_disease': 667, 'chemical_or_drug_affects_abnormal_cell': 668, 'completely_excised_anatomy_may_have_procedure': 669, 'partially_excised_anatomy_may_have_procedure': 670, 'has_excluded_locale': 671, 'has_doseformgroup': 672, 'chromosomal_location_of_allele': 673, 'compositional_material_of': 674, 'inactive_ingredient_of': 675, 'active_moiety_of': 676, 'constitutes': 677, 'ingredients_of': 678, 'risk_factor_of': 679, 'may_be_normal_cell_origin_of_disease': 680, 'definitional_manifestation_of': 681, 'has_surgical_approach': 682, 'has_possibly_included_approach': 683, 'has_approach': 684, 'has_priority': 685, 'contraindicated_mechanism_of_action_of': 686, 'has_excluded_specimen': 687, 'absorbability_of': 688, 'has_actual_outcome': 689, 'has_expected_outcome': 690, 'modified_by': 691, 'has_time_aspect': 692, 'has_pharmaceutical_administration_method': 693, 'indirect_device_of': 694, 'procedure_device_of_possibly_included': 695, 'substance_used_by_possibly_included': 696, 'procedure_device_of_excluded': 697, 'pathology_of_excluded': 698, 'pathology_of_possibly_included': 699, 'has_excluded_associated_procedure': 700, 'procedure_site_of_excluded': 701, 'substance_used_by_excluded': 702, 'surgical_extent_of_possibly_included': 703, 'surgical_extent_of_excluded': 704, 'patient_type_of': 705, 'approach_of_possibly_included': 706, 'has_possibly_included_associated_procedure': 707, 'has_intent': 708, 'associated_finding_of_possibly_included': 709, 'panel_element_of': 710, 'specimen_of_excluded': 
711, 'has_possibly_included_panel_element': 712, 'has_severity': 713, 'disease_is_grade': 714, 'eo_disease_has_property_or_attribute': 715, 'has_lateral_anatomic_location': 716, 'has_laterality': 717, 'has_surgical_extent': 718, 'is_approach_guidance_for': 719, 'has_inherent_3d_shape': 720, 'has_temporal_context': 721, 'is_sterile': 722, 'has_lateral_location_presence': 723, 'has_finding_context': 724, 'has_possibly_included_surgical_extent': 725, 'has_surface_texture': 726, 'has_time_modifier': 727, 'has_count': 728, 'has_excluded_approach': 729, 'has_scale_type': 730, 'inherent_location_of': 731, 'disease_has_molecular_abnormality': 732, 'disease_excludes_molecular_abnormality': 733, 'is_not_metastatic_anatomic_site_of_disease': 734, 'version_of': 735, 'has_origin': 736, 'has_insertion': 737, 'right_lateral_to': 738, 'right_medial_to': 739, 'left_lateral_to': 740, 'left_medial_to': 741, 'bounded_by': 742, 'has_part_anatomy_structure': 743, 'external_to': 744, 'has_continuation_branch': 745, 'has_lymphatic_drainage': 746, 'develops_from': 747, 'merges_with': 748, 'fuses_with': 749, 'has_pcdc_data_type': 750, 'has_inheritance_type': 751, 'may_be_qualified_by': 752, 'neoplasm_has_special_category': 753, 'has_excluded_component': 754, 'basis_of_strength_substance_of': 755, 'precise_active_ingredient_of': 756, 'concentration_strength_numerator_unit_of': 757, 'concentration_strength_denominator_unit_of': 758, 'doseformgroup_of': 759, 'has_product_monograph_title': 760, 'consists_of': 761, 'presentation_strength_numerator_unit_of': 762, 'presentation_strength_denominator_unit_of': 763, 'unit_of_presentation_of': 764, 'physiologic_state_of': 765, 'moved_from': 766, 'manufactured_by': 767, 'has_process_duration': 768, 'has_realization': 769, 'has_timing_of': 770, 'common_name_of': 771, 'has_common_name': 772, 'has_access': 773, 'included_in': 774, 'includes': 775, 'has_archetype': 776, 'has_loinc_number': 777, 'suffix_of': 778, 'divisor_of': 779, 'answer_to': 780, 
'supersystem_of': 781, 'challenge_of': 782, 'adjustment_of': 783, 'count_of': 784, 'approach_of_excluded': 785, 'patient_type_of_excluded': 786, 'associated_finding_of_excluded': 787, 'route_of_administration_of_possibly_included': 788, 'component_of_possibly_included': 789, 'panel_element_of_possibly_included': 790, 'measurement_method_of': 791, 'component_of_excluded': 792, 'locale_of': 793, 'has_scale': 794, 'precondition_of': 795, 'specimen_source_morphology_of': 796, 'includes_sub_specimen': 797, 'has_presentation_strength_numerator_unit': 798, 'has_concentration_strength_numerator_unit': 799, 'has_concentration_strength_denominator_unit': 800, 'has_units': 801, 'has_episodicity': 802, 'has_revision_status': 803, 'has_excluded_surgical_extent': 804, 'disease_is_stage': 805, 'has_pathological_process': 806, 'specimen_source_identity_of': 807, 'has_ingredient_qualitative_strength': 808, 'has_supported_concept_property': 809, 'after': 810, 'icdc_value_of': 811, 'answer_to_is_sterile': 812, 'time_modifier_of': 813, 'inferolateral_to': 814, 'posterolateral_to': 815, 'inferomedial_to': 816, 'superomedial_to': 817, 'superolateral_to': 818, 'chemical_or_drug_is_metabolized_by_enzyme': 819, 'has_physiologic_effect': 820, 'contraindicated_physiologic_effect_of': 821, 'is_subject_of': 822, 'is_imaged_location_for': 823, 'has_exam': 824, 'is_modality_type_for': 825, 'has_connection_type': 826, 'abnormal_cell_affected_by_chemical_or_drug': 827, 'has_data_element': 828, 'allele_plays_altered_role_in_process': 829, 'treated_by': 830, 'forms': 831, 'has_full_grown_phenotype': 832, 'filling_of': 833, 'has_disposition': 834, 'quantified_form_of': 835, 'allele_in_chromosomal_location': 836, 'biomarker_type_includes_gene': 837, 'anatomical_entity_observed_in': 838, 'segmental_composition_of': 839, 'internal_to': 840, 'is_imaging_focus_of': 841, 'is_timing_for': 842, 'is_pharmaceutical_route_for': 843, 'is_given_pharmaceutical_substance_for': 844, 'is_modality_subtype_for': 845, 
'is_presence_of_lateral_location': 846, 'is_lateral_anatomic_location_of': 847, 'is_aggregation_view_of': 848, 'has_object_guidance': 849, 'has_action_guidance': 850, 'has_presence_guidance': 851, 'is_view_type_for': 852, 'has_approach_guidance': 853, 'is_maneuver_type_for': 854, 'pathway_has_gene_element': 855, 'disease_is_marked_by_gene': 856, 'diagnoses': 857, 'treats': 858, 'degree_of': 859, 'has_degree': 860, 'has_version': 861, 'has_quantified_form': 862, 'negatively_regulated_by': 863, 'positively_regulated_by': 864, 'inverse_ends_during': 865, 'ends_during': 866, 'has_segmental_supply': 867, 'has_secondary_segmental_supply': 868, 'has_primary_segmental_supply': 869, 'has_segmental_composition': 870, 'germ_origin_of': 871, 'has_direct_cell_shape': 872, 'direct_cell_shape_of': 873, 'formed_by': 874, 'fusion_of': 875, 'inverse_relative_to': 876, 'extended_to_by_process': 877, 'coating_material_of': 878, 'may_qualify': 879, 'has_dose_form_intended_site': 880, 'imaged_anatomy_has_procedure': 881, 'inverse_happens_during': 882, 'part_referred_to_by': 883, 'is_physical_location_of_gene': 884, 'site_of_metabolism': 885, 'has_pharmacokinetics': 886, 'has_ctcae_5_parent': 887, 'ctcae_5_parent_of': 888, 'phenotype_of': 889, 'has_allelic_variant': 890, 'has_adjustment': 891, 'archetype_of': 892, 'human_disease_maps_to_eo_disease': 893, 'route_of_administration_of_excluded': 894, 'gene_product_sequence_variation_encoded_by_gene_mutant': 895, 'is_property_or_attribute_of_eo_disease': 896, 'eo_disease_has_associated_eo_anatomy': 897, 'gene_product_has_abnormality': 898, 'cell_type_is_associated_with_eo_disease': 899, 'gene_has_physical_location': 900, 'allele_absent_from_wild-type_chromosomal_location': 901, 'eo_disease_has_associated_cell_type': 902, 'has_possibly_included_route_of_administration': 903, 'has_excluded_route_of_administration': 904, 'has_dose_form_administration_method': 905, 'smaller_than': 906, 'larger_than': 907, 'value_set_is_paired_with': 908, 
'units_of': 909, 'abnormality_associated_with_allele': 910, 'process_altered_by_allele': 911, 'activity_of_allele': 912, 'allele_has_activity': 913, 'kind_is_domain_of': 914, 'kind_is_range_of': 915, 'role_is_parent_of': 916, 'chemical_or_drug_metabolism_is_associated_with_allele': 917, 'chromosomal_location_of_wild-type_gene': 918, 'role_has_parent': 919, 'qualifier_applies_to': 920, 'is_qualified_by': 921, 'gene_mutant_encodes_gene_product_sequence_variation': 922, 'gene_product_has_gene_product_variant': 923, 'is_paired_with_value_set': 924, 'disease_may_have_normal_tissue_origin': 925, 'has_pharmaceutical_transformation': 926, 'has_dose_form_transformation': 927, 'is_presence_guidance_for': 928, 'grade_of': 929, 'episodicity_of': 930, 'place_traveled_from': 931, 'locale_of_excluded': 932, 'patient_type_of_possibly_included': 933, 'reformulated_to': 934, 'plays_role': 935, 'has_supported_concept_relationship': 936, 'has_owning_section': 937, 'data_element_of': 938, 'dependent_of': 939, 'has_dependent': 940, 'moved_to': 941, 'has_life_circumstance': 942, 'sub_specimen_included_by': 943, 'has_pharmaceutical_basic_dose_form': 944, 'has_developmental_stage': 945, 'corresponds_to': 946, 'matures_from': 947, 'has_fusion': 948, 'has_observed_anatomical_entity': 949, 'has_development_type': 950, 'has_presentation_strength_denominator_unit': 951, 'has_unit_of_presentation': 952, 'has_dose_form_release_characteristic': 953, 'surface_texture_of': 954, 'inc_parent_of': 955, 'has_inc_parent': 956, 'has_pharmaceutical_release_characteristics': 957, 'process_duration_of': 958, 'has_parent_group': 959, 'has_absorbability': 960, 'target_population_of': 961, 'ingredient_qualitative_strength_of': 962, 'has_target_population': 963, 'may_be_normal_tissue_origin_of_disease': 964}\n",
      "Processing UMLSDataset base dataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|███████████████████████▍                                                                                           | 8928541/43842950 [02:08<06:51, 84924.09it/s]"
     ]
    }
   ],
   "source": [
    "umls_ds = UMLSDataset(\n",
    "    root=\"/data/pj20/umls/\",\n",
    "    # root=\"https://storage.googleapis.com/pyhealth/umls/\",\n",
    "    dev=False,\n",
    "    refresh_cache=False\n",
    ")\n",
    "\n",
    "# check the dataset statistics before setting task\n",
    "print(umls_ds.stat()) \n",
    "\n",
    "# check the relation numbers in the dataset\n",
    "print(\"Relations in KG:\", umls_ds.relation2id)\n",
    "\n",
    "umls_ds = umls_ds.set_task(link_prediction_fn, negative_sampling=64, save=False)\n",
    "\n",
    "model = TransE(\n",
    "    dataset=umls_ds,\n",
    "    e_dim=512, \n",
    "    r_dim=512, \n",
    ")\n",
    "\n",
    "print('Loaded model: ', model)\n",
    "state_dict = torch.load(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt\")\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/model.pkl\", \"wb\") as f:\n",
    "    pickle.dump(model, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(model.R_emb), len(model.E_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl\", \"wb\") as f:\n",
    "    pickle.dump(model.E_emb, f)\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl\", \"wb\") as f:\n",
    "    pickle.dump(model.R_emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Pre-trained Entity Embedding and Relation Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/E_emb.pkl\", \"rb\") as f:\n",
    "    E_emb = pickle.load(f)\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/R_emb.pkl\", \"rb\") as f:\n",
    "    R_emb = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2entity.json\", \"r\") as f:\n",
    "    id2entity = json.load(f)\n",
    "\n",
    "with open(\"/data/pj20/umls_kge/pretrained_model/umls_transe_new/id2relation.json\", \"r\") as f:\n",
    "    id2relation = json.load(f)\n",
    "\n",
    "entity2id = {v: k for k, v in id2entity.items()}\n",
    "relation2id = {v: k for k, v in id2relation.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ATC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the mapping from ATC to UMLS\n",
    "atc_umls = pd.read_csv(\"../resource/ATC_to_UMLS.csv\", header=None)\n",
    "atc_umls_cuis = atc_umls[1].tolist()[1:]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if there are any ATC codes that are not in the UMLS\n",
    "\n",
    "cnt = 0\n",
    "not_covered_cui_atc = []\n",
    "for cui in tqdm(atc_umls_cuis):\n",
    "    if cui not in entity2id.keys():\n",
    "        not_covered_cui_atc.append(cui)\n",
    "        cnt+=1\n",
    "\n",
    "cnt, not_covered_cui_atc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the embeddings\n",
    "\n",
    "atc_to_umls = {}\n",
    "for i in tqdm(range(len(atc_umls))):\n",
    "    if atc_umls[1][i] != \"UMLS\":\n",
    "        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]\n",
    "\n",
    "atc_id2emb = {}\n",
    "for atc_id in tqdm(atc_to_umls.keys()):\n",
    "    atc_id2emb[atc_id] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/drugs/atc.json\", \"w\") as f:\n",
    "    json.dump(atc_id2emb, f, indent=6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')\n",
    "\n",
    "atc_to_umls = {}\n",
    "for i in tqdm(range(len(atc_umls))):\n",
    "    if atc_umls[1][i] != \"UMLS\":\n",
    "        atc_to_umls[atc_umls[0][i]] = atc_umls[1][i]\n",
    "        \n",
    "atc_umls_dict = defaultdict(dict)\n",
    "\n",
    "for atc_id in tqdm(atc_to_umls.keys()):\n",
    "    atc_umls_dict[atc_id]['UMLS CUI'] = atc_to_umls[atc_id]\n",
    "    atc_umls_dict[atc_id]['UMLS-KG Embedding'] = E_emb[int(entity2id[atc_to_umls[atc_id]])].detach().numpy().tolist()\n",
    "    \n",
    "data['UMLS CUI'] = ''\n",
    "data['UMLS-KG Embedding'] = ''\n",
    "\n",
    "for index, row in data.iterrows():\n",
    "    code = row['code']\n",
    "    if code in atc_umls_dict:\n",
    "        data.at[index, 'UMLS CUI'] = atc_umls_dict[code]['UMLS CUI']\n",
    "        data.at[index, 'UMLS-KG Embedding'] = atc_umls_dict[code]['UMLS-KG Embedding']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Assuming other necessary variables like `data`, `E_emb`, and `entity2id` are already defined...\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ATC.csv')\n",
    "data = data.drop(columns=['description'])\n",
    "data = data.drop(columns=['indication'])\n",
    "# Define the output file paths\n",
    "\n",
    "embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv'\n",
    "metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Metadata.tsv'\n",
    "\n",
    "# Create empty lists to store valid embeddings and metadata\n",
    "valid_embeddings = []\n",
    "valid_metadata = []\n",
    "\n",
    "# Loop through each row in your metadata DataFrame\n",
    "for _, row in tqdm(data.iterrows(), total=data.shape[0]):\n",
    "    # Get the ATC code from the current row\n",
    "    atc_id = row['code']\n",
    "    \n",
    "    # Check if ATC code has a corresponding UMLS CUI and embedding\n",
    "    if atc_id in atc_to_umls:\n",
    "        umls_cui = atc_to_umls[atc_id]\n",
    "        \n",
    "        # Get and format the embedding\n",
    "        embedding = E_emb[int(entity2id[umls_cui])].detach().numpy().tolist()\n",
    "        embedding_str = '\\t'.join(map(str, embedding))\n",
    "        \n",
    "        # Append the embedding and metadata to the respective lists\n",
    "        valid_embeddings.append([embedding_str])\n",
    "        \n",
    "        # Add UMLS CUI to the row before appending to valid_metadata\n",
    "        row_dict = row.to_dict()\n",
    "        row_dict['UMLS CUI'] = umls_cui\n",
    "        valid_metadata.append(row_dict)\n",
    "\n",
    "# Convert lists to DataFrames\n",
    "valid_embeddings_df = pd.DataFrame(valid_embeddings)\n",
    "valid_metadata_df = pd.DataFrame(valid_metadata)\n",
    "\n",
    "# Save the valid embeddings and metadata to TSV files\n",
    "valid_embeddings_df.to_csv(embedding_file_path, sep='\\t', index=False, header=False)\n",
    "valid_metadata_df.to_csv(metadata_file_path, sep='\\t', index=False, header=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(valid_embeddings_df), len(valid_metadata_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display whichever `data` frame is currently bound in the notebook namespace.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ICD-9-CM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the mapping from ICD-9-CM to UMLS\n",
    "icd9cm_umls = pd.read_csv(\"../resource/ICD9CM_to_UMLS.csv\", header=None)\n",
    "icd9cm_umls_cuis = icd9cm_umls[1].tolist()[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if there are any ICD-9-CM codes that are not in the UMLS\n",
    "\n",
    "cnt = 0\n",
    "not_covered_cui_icd9cm = []\n",
    "for cui in tqdm(icd9cm_umls_cuis):\n",
    "    if cui not in entity2id.keys():\n",
    "        not_covered_cui_icd9cm.append(cui)\n",
    "        cnt+=1\n",
    "\n",
    "cnt, not_covered_cui_icd9cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the embeddings\n",
    "\n",
    "icd9cm_to_umls = {}\n",
    "for i in tqdm(range(len(icd9cm_umls))):\n",
    "    if icd9cm_umls[1][i] != \"UMLS\":\n",
    "        icd9cm_to_umls[icd9cm_umls[0][i]] = icd9cm_umls[1][i]\n",
    "\n",
    "icd9cm_id2emb = {}\n",
    "for icd9cm_id in tqdm(icd9cm_to_umls.keys()):\n",
    "    key = ICD9CM.standardize(icd9cm_id).replace('.', '')\n",
    "    icd9cm_id2emb[key] = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_id]])].detach().numpy().tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/conditions/icd9cm.json\", \"w\") as f:\n",
    "    json.dump(icd9cm_id2emb, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9CM.csv')\n",
    "\n",
    "icd9cm_to_umls = {}\n",
    "for i in tqdm(range(len(icd9cm_umls))):\n",
    "    if icd9cm_umls[1][i] != \"UMLS\":\n",
    "        icd9cm_to_umls[icd9cm_umls[0][i]] = icd9cm_umls[1][i]\n",
    "        \n",
    "icd9cm_umls_dict = defaultdict(dict)\n",
    "\n",
    "for icd9cm_key in tqdm(icd9cm_to_umls.keys()):\n",
    "    icd9cm_umls_dict[icd9cm_key]['UMLS CUI'] = icd9cm_to_umls[icd9cm_key]\n",
    "    icd9cm_umls_dict[icd9cm_key]['UMLS-KG Embedding'] = E_emb[int(entity2id[icd9cm_to_umls[icd9cm_key]])].detach().numpy().tolist()\n",
    "    \n",
    "data['UMLS CUI'] = ''\n",
    "data['UMLS-KG Embedding'] = ''\n",
    "\n",
    "for index, row in data.iterrows():\n",
    "    code = row['code']\n",
    "    if code in icd9cm_umls_dict:\n",
    "        data.at[index, 'UMLS CUI'] = icd9cm_umls_dict[code]['UMLS CUI']\n",
    "        data.at[index, 'UMLS-KG Embedding'] = icd9cm_umls_dict[code]['UMLS-KG Embedding']\n",
    "\n",
    "data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Load your CSV\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM.csv')\n",
    "\n",
    "# 1. Create Embedding File\n",
    "# Extract and process the UMLS-KG Embedding\n",
    "embedding_data = data['UMLS-KG Embedding'].apply(lambda x: pd.Series(eval(x)))  # Using eval to convert string to list\n",
    "\n",
    "# Save to TSV without headers and index\n",
    "embedding_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Embedding.tsv', sep='\\t', index=False, header=False)\n",
    "\n",
    "# 2. Create Metadata File\n",
    "# Use all columns except 'UMLS-KG Embedding' as metadata\n",
    "metadata_data = data.drop(columns=['UMLS-KG Embedding'])\n",
    "\n",
    "# Save to TSV with headers and without index\n",
    "metadata_data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9CM_Metadata.tsv', sep='\\t', index=False, header=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ICD-9-PROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the mapping from ICD-9-proc to UMLS\n",
    "icd9proc_umls = pd.read_csv(\"../resource/ICD9CM_to_UMLS.csv\", header=None)\n",
    "icd9proc_umls_cuis = icd9proc_umls[1].tolist()[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if there are any ICD-9-proc codes that are not in the UMLS\n",
    "\n",
    "cnt = 0\n",
    "not_covered_cui_icd9proc = []\n",
    "for cui in tqdm(icd9proc_umls_cuis):\n",
    "    if cui not in entity2id.keys():\n",
    "        not_covered_cui_icd9proc.append(cui)\n",
    "        cnt+=1\n",
    "\n",
    "cnt, not_covered_cui_icd9proc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the embeddings\n",
    "\n",
    "icd9proc_to_umls = {}\n",
    "for i in tqdm(range(len(icd9proc_umls))):\n",
    "    if icd9proc_umls[1][i] != \"UMLS\":\n",
    "        icd9proc_to_umls[icd9proc_umls[0][i]] = icd9proc_umls[1][i]\n",
    "\n",
    "icd9proc_id2emb = {}\n",
    "for icd9proc_id in tqdm(icd9proc_to_umls.keys()):\n",
    "    key = ICD9PROC.standardize(icd9proc_id).replace('.', '')\n",
    "    icd9proc_id2emb[key] = E_emb[int(entity2id[icd9proc_to_umls[icd9proc_id]])].detach().numpy().tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/procedures/icd9proc.json\", \"w\") as f:\n",
    "    json.dump(icd9proc_id2emb, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the ICD-9-PROC code -> UMLS CUI mapping built in the previous cell.\n",
    "icd9proc_to_umls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')\n",
    "\n",
    "icd9proc_to_umls = {}\n",
    "for i in tqdm(range(len(icd9proc_umls))):\n",
    "    if icd9proc_umls[1][i] != \"UMLS\":\n",
    "        icd9proc_to_umls[icd9proc_umls[0][i]] = icd9proc_umls[1][i]\n",
    "\n",
    "icd9proc_umls_dict = defaultdict(dict)\n",
    "\n",
    "for icd9proc_key in tqdm(icd9proc_to_umls.keys()):\n",
    "    icd9proc_umls_dict[icd9proc_key]['UMLS CUI'] = icd9proc_to_umls[icd9proc_key]\n",
    "    icd9proc_umls_dict[icd9proc_key]['UMLS-KG Embedding'] = E_emb[int(entity2id[icd9proc_to_umls[icd9proc_key]])].detach().numpy().tolist()\n",
    "    \n",
    "data['UMLS CUI'] = ''\n",
    "data['UMLS-KG Embedding'] = ''\n",
    "\n",
    "for index, row in data.iterrows():\n",
    "    code = row['code']\n",
    "    if code in icd9proc_umls_dict:\n",
    "        data.at[index, 'UMLS CUI'] = icd9proc_umls_dict[code]['UMLS CUI']\n",
    "        data.at[index, 'UMLS-KG Embedding'] = icd9proc_umls_dict[code]['UMLS-KG Embedding']\n",
    "\n",
    "data.to_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Assuming other necessary variables like `data`, `E_emb`, and `entity2id` are already defined...\n",
    "data = pd.read_csv('/home/pj20/PyHealth/pyhealth/medcode/resource/ICD9PROC.csv')\n",
    "# Define the output file paths\n",
    "\n",
    "embedding_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Embedding.tsv'\n",
    "metadata_file_path = '/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ICD9PROC_Metadata.tsv'\n",
    "\n",
    "# Create empty lists to store valid embeddings and metadata\n",
    "valid_embeddings = []\n",
    "valid_metadata = []\n",
    "\n",
    "# Loop through each row in your metadata DataFrame\n",
    "for _, row in tqdm(data.iterrows(), total=data.shape[0]):\n",
    "    # Get the ATC code from the current row\n",
    "    atc_id = row['code']\n",
    "    \n",
    "    # Check if ATC code has a corresponding UMLS CUI and embedding\n",
    "    if atc_id in icd9proc_to_umls:\n",
    "        umls_cui = icd9proc_to_umls[atc_id]\n",
    "        \n",
    "        # Get and format the embedding\n",
    "        embedding = E_emb[int(entity2id[umls_cui])].detach().numpy().tolist()\n",
    "        embedding_str = '\\t'.join(map(str, embedding)).replace('\\\"', '')\n",
    "        \n",
    "        # Append the embedding and metadata to the respective lists\n",
    "        valid_embeddings.append([embedding_str])\n",
    "        \n",
    "        # Add UMLS CUI to the row before appending to valid_metadata\n",
    "        row_dict = row.to_dict()\n",
    "        row_dict['UMLS CUI'] = umls_cui\n",
    "        valid_metadata.append(row_dict)\n",
    "\n",
    "# Convert lists to DataFrames\n",
    "valid_embeddings_df = pd.DataFrame(valid_embeddings)\n",
    "valid_metadata_df = pd.DataFrame(valid_metadata)\n",
    "\n",
    "# Save the valid embeddings and metadata to TSV files\n",
    "valid_embeddings_df.to_csv(embedding_file_path, sep='\\t', index=False, header=False)\n",
    "valid_metadata_df.to_csv(metadata_file_path, sep='\\t', index=False, header=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display whichever `data` frame is currently bound in the notebook namespace.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "  \n",
    "lines_new = []  \n",
    "for line in lines:\n",
    "    lines_new.append(line.replace('\\\"', ''))\n",
    "    \n",
    "with open('/home/pj20/PyHealth/pyhealth/medcode/resource/embeddings/KG/transe/detailed/ATC_Embedding.tsv', 'w') as f:\n",
    "    f.writelines(lines_new)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CCSCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icd9cm_to_ccscm = {}\n",
    "\n",
    "with open(\"../resource/ICD9CM_to_CCSCM.csv\", \"r\") as f:\n",
    "    reader = csv.reader(f)\n",
    "    for row in reader:\n",
    "        if row[1] != 'CCSCM':\n",
    "            icd9cm_to_ccscm[row[0]] = row[1]\n",
    "\n",
    "ccscm_to_icd9cm = defaultdict(list)\n",
    "for k, v in icd9cm_to_ccscm.items():\n",
    "    ccscm_to_icd9cm[v].append(k)\n",
    "\n",
    "ccscm_icd9cm = {}\n",
    "for k, v in ccscm_to_icd9cm.items():\n",
    "    ccscm_icd9cm[k] = v[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the embeddings\n",
    "ccscm_id2emb = {}\n",
    "for ccscm_id in tqdm(ccscm_icd9cm.keys()):\n",
    "    try:\n",
    "        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id]]])].detach().numpy().tolist()\n",
    "    except:\n",
    "        ccscm_id2emb[ccscm_id] = E_emb[int(entity2id[icd9cm_to_umls[ccscm_icd9cm[ccscm_id].replace('.00', '')]])].detach().numpy().tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/conditions/ccscm.json\", \"w\") as f:\n",
    "    json.dump(ccscm_id2emb, f, indent=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CCSPROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icd9proc_to_ccsproc = {}\n",
    "\n",
    "with open(\"../resource/ICD9PROC_to_CCSPROC.csv\", \"r\") as f:\n",
    "    reader = csv.reader(f)\n",
    "    for row in reader:\n",
    "        if row[1] != 'CCSPROC':\n",
    "            icd9proc_to_ccsproc[row[0]] = row[1]\n",
    "\n",
    "ccsproc_to_icd9proc = defaultdict(list)\n",
    "for k, v in icd9proc_to_ccsproc.items():\n",
    "    ccsproc_to_icd9proc[v].append(k)\n",
    "\n",
    "ccsproc_icd9proc = {}\n",
    "for k, v in ccsproc_to_icd9proc.items():\n",
    "    ccsproc_icd9proc[k] = v[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the embeddings\n",
    "ccsproc_id2emb = {}\n",
    "for ccsproc_id in tqdm(ccsproc_icd9proc.keys()):\n",
    "    try:\n",
    "        ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[ccsproc_icd9proc[ccsproc_id]]])].detach().numpy().tolist()\n",
    "    except:\n",
    "        try:\n",
    "            icd9procid = ccsproc_icd9proc[ccsproc_id]\n",
    "            if icd9procid[0] == '0':\n",
    "                icd9procid = icd9procid[1:]\n",
    "            if icd9procid[-1] == '0':\n",
    "                icd9procid = icd9procid[:-1]\n",
    "            if icd9procid[-1] == '0':\n",
    "                icd9procid = icd9procid[:-2]\n",
    "            if icd9procid[-1] == '.':\n",
    "                icd9procid = icd9procid[:-1]\n",
    "\n",
    "            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()\n",
    "            \n",
    "        except:\n",
    "            icd9procid = icd9procid[:-1]\n",
    "            ccsproc_id2emb[ccsproc_id] = E_emb[int(entity2id[icd9proc_to_umls[icd9procid]])].detach().numpy().tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/procedures/ccsproc.json\", \"w\") as f:\n",
    "    json.dump(ccsproc_id2emb, f, indent=6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Special Tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "special_tokens = {}\n",
    "tokens = ['<pad>', '<unk>']\n",
    "\n",
    "for token in tokens:\n",
    "    special_tokens[token] = np.random.randn(512).tolist()\n",
    "\n",
    "with open(f\"../resource/embeddings/KG/special_tokens/special_tokens.json\", \"w\") as f:\n",
    "    json.dump(special_tokens, f, indent=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "79cb95e61c4f960f4e102f21c45668d32cb5c494b237694c15d64b50342e6e99"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
